# Inherited base configs: dataset pipeline/loaders, runtime hooks, and the
# 160k-iteration schedule. Fields assigned further down in this file override
# or extend what these bases define (e.g. the training length and LR policy).
_base_ = [
    '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py',
]
# model settings
# Cross-GPU synchronized BatchNorm, shared (same object) by both heads below.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}
# Total number of training iterations; reused by the LR schedule and the
# train loop configured later in this file.
training_steps = 80000
# Input normalization and padding applied before the model sees a batch.
# mean/std are the standard ImageNet statistics on the 0-255 scale.
data_preprocessor = {
    'type': 'SegDataPreProcessor',
    'size': (512, 512),
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'bgr_to_rgb': True,
    'pad_val': 0,
    # padded label pixels get 255 (presumably the ignore index — confirm
    # against the dataset/loss config)
    'seg_pad_val': 255,
}
model = {
    'type': 'EncoderDecoder',
    'data_preprocessor': data_preprocessor,
    'pretrained': None,
    # Project-specific backbone; weights come from init_cfg, not `pretrained`.
    'backbone': {
        'type': 'RePaSwinTransformer',
        'init_cfg': {
            'type': 'Pretrained',
            'checkpoint': 'checkpoints/repaswin_base_acc_82.pth',
        },
        'img_size': 224,
        'patch_size': 4,
        'in_chans': 3,
        'embed_dim': 128,
        'depths': [2, 2, 18, 2],
        'num_heads': [4, 8, 16, 32],
        'window_size': 7,
        'mlp_ratio': 4.0,
        'qkv_bias': True,
        'qk_scale': None,
        'ape': False,
        'patch_norm': True,
        'use_checkpoint': False,
        'fused_window_process': False,
        'channel_idle': True,
        'drop_path_rate': 0.1,
        'feature_norm': 'BatchNorm',
        # emit all four stage outputs for the UPerNet decoder
        'out_indices': (0, 1, 2, 3),
    },
    'decode_head': {
        'type': 'UPerHead',
        # Per-stage channels = embed_dim * 2**stage (assumes the backbone
        # doubles width per stage, Swin-style); keep in sync with embed_dim.
        'in_channels': [128, 256, 512, 1024],
        'in_index': [0, 1, 2, 3],
        'pool_scales': (1, 2, 3, 6),
        'channels': 512,
        'dropout_ratio': 0.1,
        'num_classes': 150,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    # Auxiliary FCN head on the third backbone output (512 channels);
    # its loss is down-weighted to 0.4.
    'auxiliary_head': {
        'type': 'FCNHead',
        'in_channels': 512,
        'in_index': 2,
        'channels': 256,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': 150,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 0.4,
        },
    },
    # model training and testing settings
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}

# Optimizer: AdamW. `_delete_=True` makes this replace the optim_wrapper
# inherited from the base schedule instead of being merged into it.
# `custom_keys` set weight decay to zero for parameters whose name contains
# one of the keys: position-embedding / relative-position-bias tables and
# anything with 'norm' in its name.
# NOTE(review): matching is by substring over the whole model's parameter
# names, not only the backbone, and the backbone is configured with
# feature_norm="BatchNorm" — confirm which layers 'norm' is meant to cover.
optim_wrapper = {
    '_delete_': True,
    'type': 'OptimWrapper',
    'optimizer': {
        'type': 'AdamW',
        'lr': 5e-4,
        'betas': (0.9, 0.999),
        'weight_decay': 0.05,
    },
    'paramwise_cfg': {
        'custom_keys': {
            'absolute_pos_embed': {'decay_mult': 0.0},
            'relative_position_bias_table': {'decay_mult': 0.0},
            'norm': {'decay_mult': 0.0},
        },
    },
}

# Iteration-based learning-rate schedule. This list replaces (is not merged
# with) the param_scheduler inherited from the base schedule:
#   * iters [0, 8000): linear warmup from lr * 1e-6 up to the base lr;
#   * iters [8000, training_steps): cosine annealing down to eta_min=1e-7.
param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=8000),
    dict(
        type='CosineAnnealingLR',
        begin=8000,  # must equal the warmup's `end` so the phases abut
        end=training_steps,
        eta_min=1e-7,
        by_epoch=False,
    ),
]

# Data loading. These dicts override only the listed fields of the loader
# configs inherited from the dataset base config (assumed to define the rest).
# In mmengine `batch_size` is per GPU (per process): 16 images per GPU here,
# i.e. a total batch of 16 * num_gpus (128 on 8 GPUs).
# NOTE(review): confirm the intended GPU count — the effective total batch,
# and therefore a suitable lr, scales with it.
train_dataloader = dict(batch_size=16, num_workers=24)
# Validation runs one image per GPU; testing reuses the same loader dict.
val_dataloader = dict(batch_size=1)
test_dataloader = val_dataloader
# Train for `training_steps` iterations, validating every 4000 iterations.
train_cfg = dict(
    max_iters=training_steps, type='IterBasedTrainLoop', val_interval=4000)